import numpy as np

from toolkit.utils import agrid, markov_rouwenhorst, variance, broyden_solver
from toolkit.simple_block import simple
from toolkit.solved_block import solved
from toolkit.jacobian import get_G
from toolkit.nonlinear import td_solve
import ghh_plus

'''Part 1: Embed HA block in macro model'''


def taxes(T_rule, pi_e, T0, T1, tax0, tax1, tax_rule):
    """Build the transfer schedule T and the tax schedule tax from incidence rules.

    T0, T1, tax0 and tax1 are scalars; T_rule and tax_rule are arrays over household states.
    The variable part of transfers is rescaled so that ``sum(pi_e * (T - T0)) == T1``,
    so only the shape of T_rule matters. tax_rule enters unscaled, multiplied by tax1.
    """
    rule_mass = np.sum(pi_e * T_rule)
    transfer_per_unit = T1 / rule_mass
    T = T0 + transfer_per_unit * T_rule
    tax = tax0 + tax1 * tax_rule
    return T, tax


# Household block with taxes() attached as a heterogeneous input: taxes() turns its arguments
# (T_rule, pi_e, T0, T1, tax0, tax1, tax_rule) into the state-specific schedules T and tax.
# NOTE(review): exactly when the toolkit evaluates the het-input (presumably before each household
# solve) is defined by attach_hetinput in ghh_plus/toolkit — confirm there.
household = ghh_plus.household.attach_hetinput(taxes)

'''Part 2: simple blocks'''


@simple
def production(L, Z, F):
    """Linear technology with a fixed cost: output is Z * L net of F."""
    gross_output = Z * L
    Y = gross_output - F
    return Y


@simple
def dividend(Y, w, L, pi, mup, kappap):
    """Firm payouts after wages and a price-adjustment cost that is quadratic in log(1 + pi).

    The cost is mup / (mup - 1) / (2 * kappap) * log(1 + pi)**2 * Y, which vanishes at pi = 0.
    """
    cost_scale = mup / (mup - 1) / (2 * kappap)
    price_adjustment_cost = cost_scale * np.log(1 + pi) ** 2 * Y
    div = Y - w * L - price_adjustment_cost
    return price_adjustment_cost, div


@solved(unknowns=['B'], targets=['debt'])
def debt_policy(G, B, B_ss, rho_B, G_ss):
    """Fiscal debt rule, solved for the path of B.

    Zero residual means B = (1 - rho_B) * B_ss + rho_B * (B(-1) + G - G_ss): debt partially
    reverts to B_ss while absorbing spending deviations from G_ss.
    """
    rule_value = (1 - rho_B) * B_ss + rho_B * (B(-1) + G - G_ss)
    debt = rule_value - B
    return debt


# T stays constant, tax adjusts according to arbitrary incidence rule
@simple
def gov_budget1(rb, REV, B, G, T0):
    """Government budget residual with transfers held at T0.

    Zero residual means B + REV = (1 + rb) * B(-1) + G + T0.
    """
    new_resources = B + REV
    gov_budget_res = new_resources - (1 + rb) * B(-1) - G - T0
    return gov_budget_res


@simple
def gov_budget2(w, L, tax1, REV_ss):
    """Set the tax level tax0 so that tax0 * w * L equals REV_ss, and report tax_all = tax0 + tax1."""
    revenue_per_wage = REV_ss / w
    tax0 = revenue_per_wage / L
    # combined rate, used for plotting in the special case of a flat tax
    tax_all = tax0 + tax1
    return tax0, tax_all


@simple
def mkt_clearing(p, A, B, L, NS, Y, C, G, price_adjustment_cost):
    """Residuals of the asset, labor and goods markets (each zero in equilibrium)."""
    # assets supplied: equity value p plus government bonds B, versus household holdings A
    asset_supply = p + B
    asset_mkt = asset_supply - A
    # labor demanded L versus labor supplied NS
    labor_mkt = L - NS
    # uses of output versus output Y
    goods_demand = C + G + price_adjustment_cost
    goods_mkt = goods_demand - Y
    return asset_mkt, labor_mkt, goods_mkt


@solved(unknowns=['p'], targets=['equity'])
def arbitrage(div, p, re):
    """Equity pricing condition, solved for the price path p.

    Zero residual means p * (1 + re) = div(+1) + p(+1).
    """
    next_payoff = div(+1) + p(+1)
    equity = next_payoff - p * (1 + re)
    return equity


@simple
def interest_rates(re, rb, div, p, pshare):
    """Household portfolio returns.

    rpost is the realized net return on a portfolio with weight pshare on equity
    (gross return (div + p) / p(-1)) and 1 - pshare on bonds (gross return 1 + rb).
    rsub simply equals re.
    """
    equity_part = pshare * (div + p) / p(-1)
    bond_part = (1 - pshare) * (1 + rb)
    rpost = equity_part + bond_part - 1
    rsub = re
    return rpost, rsub


@simple
def real_bonds(re):
    """Real bonds: today's realized bond return rb is last period's ex-ante real rate."""
    rb = re(-1)
    return rb


@simple
def nominal_bonds(i, pi):
    """Nominal bonds: ex-post real return from last period's nominal rate deflated by current inflation."""
    gross_real_return = (1 + i(-1)) / (1 + pi)
    rb = gross_real_return - 1
    return rb


@simple
def price_setting(pi, w, re, Y, Z, kappap, mup):
    """Phillips-curve residual written in log(1 + pi).

    nkpc = kappap * (w / Z - 1 / mup) + Y(+1) / Y * log(1 + pi(+1)) / (1 + re) - log(1 + pi).
    It is zero when current inflation equals the marginal-cost gap term (w / Z against 1 / mup)
    plus discounted, output-growth-weighted future inflation.
    """
    cost_gap_term = kappap * (w / Z - 1 / mup)
    future_term = Y(+1) / Y * np.log(1 + pi(+1)) / (1 + re)
    nkpc = cost_gap_term + future_term - np.log(1 + pi)
    return nkpc


@simple
def monetary(i, pi):
    """Fisher equation: ex-ante real rate from the nominal rate i and next-period inflation."""
    gross_real_rate = (1 + i) / (1 + pi(+1))
    re = gross_real_rate - 1
    return re


@simple
def taylor_rule(rstar, pi, phi):
    """Taylor rule: the nominal rate responds to current inflation with coefficient phi around rstar."""
    inflation_response = phi * pi
    i = rstar + inflation_response
    return i


@simple
def constant_r(rstar, pi):
    """Nominal rate with (1 + i) = (1 + rstar) * (1 + pi(+1)), pinning the ex-ante real rate at rstar."""
    expected_inflation_term = pi(+1) * (1 + rstar)
    i = rstar + expected_inflation_term
    return i


'''Part 3: model variations'''


def sp_benchmark(ss, T, linear=True, dG=None):
    """Compute the GE response of the benchmark sticky-price model to a government-spending shock.

    Monetary policy is closed with ``constant_r``. Combined with the Fisher relation in ``monetary``,
    this keeps the ex-ante real rate at ``rstar``.

    Parameters
    ----------
    ss : dict
        Steady state, e.g. as returned by ``calibrate``.
    T : int
        Truncation horizon for the linearized Jacobians (not used when ``linear=False``).
    linear : bool, optional
        If True (default), return the GE Jacobians from ``get_G``. Otherwise solve the
        nonlinear transition path with ``td_solve``.
    dG : array_like, optional
        Path of ``G`` in deviation from steady state. Required when ``linear=False``.

    Raises
    ------
    ValueError
        If ``linear=False`` and ``dG`` is None.
    """
    # Fail fast, before any block is touched. The original built this error but never raised it,
    # which let the call crash later in ``ss['G'] + None``.
    if not linear and dG is None:
        raise ValueError('If you want a nonlinear solution, specify the shock dG.')

    # set up DAG
    block_list = [household, production, dividend, debt_policy, gov_budget1, mkt_clearing,
                  arbitrage, interest_rates, real_bonds, constant_r, monetary, price_setting,
                  gov_budget2]
    exogenous = ['G']
    unknowns = ['pi', 'w', 'L', 'tax1']
    targets = ['asset_mkt', 'labor_mkt', 'nkpc', 'gov_budget_res']

    if linear:
        return get_G(block_list, exogenous, unknowns, targets, T=T, ss=ss)
    return td_solve(ss, block_list, unknowns, targets, returnindividual=True, G=ss['G'] + dG)


def sp_taylor(ss, T):
    """Return the GE Jacobians of the sticky-price model when monetary policy follows a Taylor rule.

    The DAG uses the same blocks as ``sp_benchmark``, except that ``taylor_rule``
    (i = rstar + phi * pi) replaces ``constant_r``. The shock is G. The unknowns are
    (pi, w, L, tax1); the targets are the asset-market, labor-market, Phillips-curve and
    government-budget residuals.
    """
    blocks = [
        household, production, dividend, debt_policy, gov_budget1,
        mkt_clearing, arbitrage, interest_rates, real_bonds,
        price_setting, monetary, taylor_rule, gov_budget2,
    ]
    shocks = ['G']
    unknown_vars = ['pi', 'w', 'L', 'tax1']
    target_vars = ['asset_mkt', 'labor_mkt', 'nkpc', 'gov_budget_res']
    jacobians = get_G(blocks, shocks, unknown_vars, target_vars, T=T, ss=ss)
    return jacobians


def sp_peg(ss, T, k=12):
    """Compute GE Jacobian of sticky-price model with a nominal peg for k periods followed by a Taylor rule.

    The peg is imposed on the Taylor-rule Jacobian: the first ``k`` rows of d i / d pi are zeroed,
    so the nominal rate does not respond to inflation in periods 0, ..., k-1.
    ``k`` counts model periods, not years. The model appears to be quarterly (see
    ``mpe_conversion``), so the default k=12 corresponds to three years.

    Parameters
    ----------
    ss : dict
        Steady state, e.g. as returned by ``calibrate``.
    T : int
        Truncation horizon of the Jacobians.
    k : int, optional
        Number of initial periods with a fixed nominal rate (must be non-negative).

    Raises
    ------
    ValueError
        If ``k`` is negative. A negative slice bound would otherwise zero all but the last |k| rows.
    """
    if k < 0:
        raise ValueError(f'k must be non-negative, got {k}.')

    # compute Jacobian of Taylor rule
    J_taylor = taylor_rule.jac(ss, T)

    # no response to inflation in first k periods
    J_taylor['i']['pi'][:k, :] = 0

    # set up DAG
    block_list = [household, production, dividend, debt_policy, gov_budget1, mkt_clearing,
                  arbitrage, interest_rates, real_bonds, price_setting,
                  monetary, J_taylor, gov_budget2]
    exogenous = ['G']
    unknowns = ['pi', 'w', 'L', 'tax1']
    targets = ['asset_mkt', 'labor_mkt', 'nkpc', 'gov_budget_res']

    return get_G(block_list, exogenous, unknowns, targets, T=T, ss=ss)


'''Part 4: Calibration'''


def calibrate(betamax_guess=0.98, vphi_guess=0.86, sigma_guess=2.0, sd_e_guess=0.8, EIS_target=0.5, r=0.005, alpha=0.0,
              nu=2.0, mup=7 / 6, sd_income=0.92, B=0.55*4, rho_B=0.9, p=0.85*4, tax0=0.334, T_w=0.143,
              pi_beta=None, Pi_beta=None, rho_e=0.966, kappap=0.01, phi=1.25, nS=11, amax=1000, nA=500, T_rule=None,
              tax_rule=None, noisy=False, bmin=0.0, bmax=0.2, MPC_tol=0.001, MPC_target=0.25, maxit=10):
    """Solve steady state of full GE model.

    Inner loop (``broyden_solver``): calibrate (betamax, vphi, sigma, sd_e) so that:
      * assets clear at the given r and p (A = B + p);
      * labor supply NS = 1;
      * the average EIS equals ``EIS_target``;
      * the distribution-weighted standard deviation of log ``ns`` equals ``sd_income``.
    Z and F are set in closed form so that Y = 1 and L = 1.

    Outer loop: bisect on the beta range ``brange`` in [bmin, bmax] until the average MPC is
    within ``MPC_tol`` of ``MPC_target``. The low beta type is betamax - brange and the high
    type is betamax. At most ``maxit`` outer evaluations are made.

    Returns
    -------
    dict
        Steady state at the last evaluated ``brange``. It holds the household outputs, the
        calibrated parameters, MPC/MPE arrays and averages, and the goods-market residual ``walras``.

    Raises
    ------
    ValueError
        If ``pi_beta`` does not have exactly two entries, if ``maxit < 1``, or if
        ``T_rule``/``tax_rule`` do not have ``nS * len(pi_beta)`` entries.

    Warns
    -----
    UserWarning
        If the bisection uses all ``maxit`` iterations without meeting ``MPC_tol``.
    """
    import warnings

    # Discount-factor types: by default two equally likely types that never switch.
    if pi_beta is None:
        pi_beta = np.array([0.5, 0.5])
    if Pi_beta is None:
        Pi_beta = np.eye(len(pi_beta))
    # res_inner below always builds exactly two beta values.
    if len(pi_beta) != 2:
        raise ValueError('pi_beta must have exactly two entries (beta types betamax - brange and betamax).')
    if maxit < 1:
        raise ValueError('maxit must be at least 1.')

    # Tax and transfer rules. T_rule is normalized inside taxes(); the scale of tax_rule is absorbed by tax1.
    n_states = nS * len(pi_beta)
    if T_rule is None:
        T_rule = np.ones(n_states)
    if tax_rule is None:
        tax_rule = np.ones(n_states)
    if not (len(T_rule) == len(tax_rule) == n_states):
        raise ValueError('Incidence rules are inconsistent with income grid.')

    # set up grid
    a_grid = agrid(amax=amax, n=nA)

    # solve analytically what we can
    Z = mup * (1 - p * r)
    F = Z - 1
    w = Z / mup
    div = 1 - w
    pshare = p / (p + B)
    T0 = T_w * w
    G = tax0 * w - r * B - T0  # uses L=1
    tax1, T1 = 0, 0
    T, _ = taxes(T_rule, 1, T0, T1, tax0, tax1, tax_rule)
    fininc = (1 + r) * a_grid[np.newaxis, :] - a_grid[0] + T[:, np.newaxis]
    atw = (1 - tax0) * w

    # initialize ss dict
    ss = {'B': B, 'phi': phi, 'kappap': kappap, 'Y': 1, 'rstar': r, 'i': r, 'mup': mup, 'L': 1, 'pi': 0, 'div': div,
          'G': G, 'G_ss': G, 'B_ss': B, 'r': r, 'rho_B': rho_B, 'price_adjustment_cost': 0, 'REV_ss': w * tax0,
          're': r, 'rb': r, 'pshare': pshare, 'p': p, 'Z': Z, 'F': F, 'tax_all': tax0}

    def res(brange0):
        """Solve the inner calibration for a given beta range; return the average-MPC residual."""

        def res_inner(x):
            """Residuals of the inner targets at guess x = (betamax, vphi, sigma, sd_e); updates ss."""
            betamax0, vphi0, sigma0, sd_e0 = x

            # Reject clearly invalid guesses before doing any work.
            # (Presumably this lets broyden_solver backtrack on ValueError; confirm in toolkit.utils.)
            if betamax0 > 0.9999 / (1 + r) or betamax0 - brange0 < 0.1 or vphi0 < 0.01 or sigma0 < 0.01 or sd_e0 < 0.01:
                raise ValueError('Clearly invalid inputs')

            # update exogenous processes
            betas0 = np.array([betamax0 - brange0, betamax0])
            e0, pi_e0, Pi_e0 = markov_rouwenhorst(rho=rho_e, sigma=sd_e0, N=nS)
            e_grid0 = np.kron(np.ones_like(betas0), e0)
            beta_grid0 = np.kron(betas0, np.ones_like(e0))
            pi0 = np.kron(pi_beta, pi_e0)
            Pi0 = np.kron(Pi_beta, Pi_e0)

            # Initial guess for policy function iteration.
            # NOTE(review): T is multiplied by atw in coh0 but enters fininc unscaled. coh0 feeds uc0,
            # which goes to solve_cn and household.ss. Confirm this is intended.
            n_ghh0 = (atw * e_grid0 / vphi0) ** (1 / nu)
            coh0 = (1 + r) * a_grid[np.newaxis, :] + atw * (e_grid0 * n_ghh0 + T)[:, np.newaxis]
            uc0 = (0.5 * coh0 - vphi0 * alpha * n_ghh0[:, np.newaxis] ** (1 + nu) / (1 + nu)) ** (-sigma0)
            c_const0, n_const0 = ghh_plus.solve_cn(atw * e_grid0[:, np.newaxis], fininc, sigma0, alpha, nu, vphi0, uc0)

            # solve HH problem
            out = household.ss(uc=uc0, Pi=Pi0, a_grid=a_grid, e_grid=e_grid0, beta_grid=beta_grid0, w=w,
                               sigma=sigma0, alpha=alpha, nu=nu, vphi=vphi0, T1=T1, T0=T0, T_rule=T_rule, pi_e=pi0,
                               tax0=tax0, tax_rule=tax_rule, tax1=tax1,
                               rpost=r, rsub=r, pi_seed=pi0, c_const=c_const0, n_const=n_const0, ssflag=True)

            # avg EIS and Comp
            eis0, comp0 = ghh_plus.elasticities(out['uc'], atw * e_grid0[:, np.newaxis], sigma0, alpha, nu, vphi0)
            EIS0 = np.vdot(out['D'], eis0)

            # income dispersion after labor supply decision
            sd_income0 = np.sqrt(variance(np.log(out['ns']), out['D']))

            # Store the model-implied moments (not the targets), consistent with 'EIS': EIS0.
            ss.update({**out, 'eis': eis0, 'comp': comp0, 'EIS': EIS0, 'sd_income': sd_income0,
                       'betamax': betamax0, 'vphi': vphi0, 'sd_e': sd_e0, 'brange': brange0})

            return np.array([out['A'] - B - p, out['NS'] - 1, EIS0 - EIS_target, sd_income0 - sd_income])

        # Solve for params conditional on betarange. The solution is read back from ss, which res_inner
        # updates on every successful evaluation. This assumes the solver's last call is at its returned
        # point; confirm in toolkit.utils.
        print(f'Starting inner loop for betarange = {brange0}')
        x0 = np.array([betamax_guess, vphi_guess, sigma_guess, sd_e_guess])
        broyden_solver(res_inner, x0, noisy=noisy)

        # MPCs
        mpc1, mpe1 = ghh_plus.mpcs_mpes(ss['c'], ss['a'], ss['tns'], a_grid, r, ss['eis'], ss['comp'], nu)
        MPC1 = np.vdot(ss['D'], mpc1)
        MPE1 = np.vdot(ss['D'], mpe1)
        print(f'MPC= {MPC1:0.3f}')

        ss.update({'mpc': mpc1, 'mpe': mpe1, 'MPC': MPC1, 'MPE': MPE1})

        return MPC1 - MPC_target

    # Bisect on betarange. The update rule assumes the average MPC increases with brange.
    for _ in range(maxit):
        bmid = (bmin + bmax) / 2
        error = res(bmid)
        if error > MPC_tol:
            # MPC too high
            bmax = bmid
        elif error < -MPC_tol:
            # MPC too low
            bmin = bmid
        else:
            break
    else:
        warnings.warn(f'Bisection on betarange did not reach MPC_tol={MPC_tol} within maxit={maxit} iterations; '
                      'returning the steady state at the last evaluated betarange.')

    # Check Walras's law: goods-market residual with Y = 1 (price adjustment cost is 0 in steady state).
    walras = 1 - ss['C'] - G

    ss.update({'walras': walras, 'ssflag': False})

    return ss


def mpe_conversion(ss, dT=1E-4):
    """Convert the quarterly MPE into an annual one.

    Raises the uniform transfer component T0 by ``dT`` in the first quarter only, and runs the
    household block for four quarters. Returns the sum over those quarters of the per-unit decline
    in aggregate ``TNS`` relative to steady state.
    """
    quarters = 4
    t0_path = np.full(quarters, ss['T0'])
    t0_path[0] += dT
    response = household.td(ss, T0=t0_path)
    per_quarter = -(response['TNS'] - ss['TNS']) / dT
    return np.sum(per_quarter)
